import os
import re
import sys
import glob
import pickle
import ujson
import datetime
import warnings
import collections

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.dates as mdates
from pandas import to_datetime

import scipy.stats as stats
import statsmodels.api as sm
import statsmodels.formula.api as smf

from pmdarima.arima import auto_arima, ndiffs, nsdiffs
from itertools import combinations
from tld import get_tld
import urlexpander
from tqdm import tqdm

import warnings
# Suppress FutureWarnings globally, including for sklearn and specific pmdarima messages.
# The environment variable is written in addition to the in-process filters below;
# presumably so that worker processes started later (e.g. by the `n_jobs=-1` model
# searches in this file) inherit the same setting — TODO confirm.
os.environ['PYTHONWARNINGS'] = 'ignore::FutureWarning'
# Replace the interpreter's recorded warning options with an empty list.
sys.warnoptions = []

# In-process filters. Each call mutates the process-wide filter list, so the order
# of these statements is kept exactly as written.
warnings.simplefilter('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=FutureWarning, module='sklearn')
# Silences the "m (<n>) set for non-seasonal fit. Setting to 0" message for modules
# matching 'pmdarima'. NOTE(review): the '.' in the pattern is unescaped and therefore
# matches any character, which is harmless for this message.
warnings.filterwarnings('ignore', message=r"m \(\d+\) set for non-seasonal fit. Setting to 0", module='pmdarima')


# --- ITS Data Preparation ---
def get_users_agg_data_its(users_data_agg, policy_week):
    """
    Build the regressor table used by the Interrupted Time Series (ITS) models.

    The input frame is left untouched; all derived columns are added to a new frame.

    Parameters:
    users_data_agg (DataFrame): Must contain 'week_old' (key into the module-level
        `week_breaks_index` mapping) and the outcome column 'y'.
    policy_week (datetime): Start date of the intervention week; rows whose mapped
        start date is on or after it count as post-intervention.

    Returns:
    DataFrame: Indexed by 'week_start_date' with the columns
        'y', 'week' (week_old + 1), 'intervention' (0/1 flag) and
        'intervention_week' (running count of post-intervention rows, in row order).
    """
    # Translate the week number into its calendar start date.
    week_start = users_data_agg['week_old'].map(week_breaks_index)

    # One boolean mask drives both the step dummy and the post-policy counter.
    post_policy = week_start >= policy_week

    prepared = users_data_agg.assign(
        week_start_date=week_start,
        intervention=post_policy.astype(int),
        week=users_data_agg['week_old'] + 1,
        intervention_week=post_policy.cumsum(),
    )

    indexed = prepared.set_index('week_start_date')
    return indexed[['y', 'week', 'intervention', 'intervention_week']]


# --- OLS Model for ITS ---
def do_ols_its(data):
    """
    Fit the segmented-regression OLS model for the ITS design.

    Parameters:
    - data (DataFrame): Must provide the columns 'y', 'week', 'intervention'
      and 'intervention_week'.

    Returns:
    - The fitted statsmodels results object.
    """
    its_formula = 'y ~ week + intervention + intervention_week'
    return smf.ols(formula=its_formula, data=data).fit()


# --- ARIMA Model ---
def do_arima(data, ols_results, m=None, seasonal=True):
    """
    Fit an ARIMA model to the data with optional seasonal component.

    The differencing orders are estimated on the OLS residuals and then handed
    to `auto_arima`, which searches the remaining orders with the ITS regressors
    ('week', 'intervention', 'intervention_week') as exogenous variables.

    Parameters:
    - data (DataFrame): Input time series data with the columns 'y', 'week',
      'intervention' and 'intervention_week'.
    - ols_results (RegressionResults): OLS results used to determine differencing orders.
    - m (int | None): The seasonal periodicity (e.g., 12 for monthly data with yearly
      seasonality). `None` is treated as 1, i.e. no seasonal period.
    - seasonal (bool): Whether to include seasonal differencing and seasonal terms.

    Returns:
    - out_model: Fitted ARIMA model.
    """
    # The signature defaults to m=None, but pmdarima compares `m` numerically;
    # fall back to a period of 1 (no seasonality) instead of failing on None.
    period = 1 if m is None else m
    residuals = ols_results.resid

    # Take the most conservative (largest) order suggested by the three unit-root tests.
    d = max(
        ndiffs(residuals, max_d=10),
        ndiffs(residuals, max_d=10, test='adf'),
        ndiffs(residuals, max_d=10, test='pp'),
    )

    # Seasonal differencing is only meaningful for a seasonal fit with period > 1;
    # the seasonal tests reject smaller periods, so they are skipped otherwise.
    if seasonal and period > 1:
        D = max(
            nsdiffs(residuals, max_D=10, m=period),
            nsdiffs(residuals, max_D=10, m=period, test='ch'),
        )
    else:
        D = 0

    in_data = data['y']
    xreg = data[['week', 'intervention', 'intervention_week']]

    out_model = auto_arima(
        in_data,
        d=d,
        D=D,
        X=xreg,
        m=period,
        seasonal=seasonal,
        stepwise=False,      # exhaustive grid search rather than the stepwise heuristic
        n_jobs=-1,
        method='powell',
        random_state=1,
        suppress_warnings=True,
        error_action='ignore',
    )

    return out_model

def do_arima_plot(data, arima_results, policy_week, name, m=None, seasonal=True):
    """
    Plot ARIMA results with optional seasonal modeling for counterfactuals.

    Draws the observed series, the fitted values of `arima_results` for the
    pre-policy period, and a counterfactual forecast obtained from a second
    ARIMA model that is fitted on pre-policy observations only. The figure is
    saved below `its_path_result` and then shown.

    Parameters:
    - data (DataFrame): Time series data indexed by week start date, with the
      columns 'y' and 'week'.
    - arima_results: Fitted ARIMA model.
    - policy_week (datetime): Date marking the policy intervention.
    - name (str): Label for y-axis and output filename.
    - m (int | None): Seasonal periodicity. `None` is treated as 1 (no seasonal period).
    - seasonal (bool): Whether to use seasonal ARIMA for counterfactual.

    Notes:
    - The vertical intervention marker is positioned from the module-level
      `policy_index`, not from `policy_week`.
    - RuntimeWarnings are silenced process-wide by this call (kept from the
      original behaviour so later steps stay quiet as before).
    """
    warnings.simplefilter(action='ignore', category=RuntimeWarning)

    pre_policy = data[data.index <= policy_week]
    post_policy = data[data.index > policy_week]

    y_pred = pd.DataFrame(arima_results.arima_res_.fittedvalues, columns=['y'])

    # Counterfactual model: fitted on pre-policy data only.
    # Guard the seasonal test the same way pmdarima needs it: a numeric period > 1.
    period = 1 if m is None else m
    n_diffs = ndiffs(pre_policy['y'], test='adf')
    if seasonal and period > 1:
        n_Diffs = nsdiffs(pre_policy['y'], m=period, max_D=3, test='ch')
    else:
        n_Diffs = 0

    arima_cf = auto_arima(
        pre_policy['y'],
        seasonal=seasonal,
        m=period,
        d=n_diffs,
        D=n_Diffs,
        start_p=0, start_q=0, max_p=5, max_q=5,
        start_P=0, start_Q=0, max_P=5, max_Q=5,
        trace=False,
        error_action='ignore',
        suppress_warnings=True,
        n_jobs=-1,
        stepwise=False
    )

    # The forecast of a fitted model is the same on every call, so it is computed
    # once; the former loop averaged 1000 identical predictions and collected
    # confidence bounds that were never plotted.
    y_cf = np.asarray(arima_cf.predict(
        n_periods=len(post_policy),
        type='levels',
        index=post_policy.index
    ))

    # Plotting
    fig, ax = plt.subplots(figsize=(30, 8))

    # The first fitted value is skipped on both axes.
    ax.plot(pre_policy["week"][1:], y_pred[y_pred.index <= policy_week]["y"][1:],
            color='black', linestyle='--', linewidth=3, label='Model prediction')

    ax.plot(data["week"], data["y"],
            color='gray', linewidth=3, label=f'Observed {name}')

    ax.plot(post_policy["week"], y_cf,
            color='dimgray', marker='.', linestyle='None', markersize=10,
            label="Counterfactual")

    ax.axvline(x=policy_index + 1.5, color='darkgray', linewidth=3, label='Proud Boys Ban')

    # Label only the first weekly observation of each calendar month.
    month_of_row = data.index.to_series().dt.to_period("M")
    first_weekly_dates_of_month = data.index[month_of_row != month_of_row.shift()]
    show_x_ticklabels = np.array(data[data.index.isin(first_weekly_dates_of_month)].week.tolist())

    ax.xaxis.set_ticks(show_x_ticklabels)
    ax.set_xticklabels(data[data.week.isin(show_x_ticklabels)].index.strftime('%Y-%m').values,
                       rotation=45, fontsize=20)

    ax.yaxis.set_tick_params(labelsize=20)
    ax.legend(loc='upper left', fontsize=20)
    ax.set_ylabel(name, fontsize=20)
    ax.title.set_size(30)

    fig.tight_layout()

    # Make sure the target folder exists so savefig does not fail on a fresh checkout.
    out_file = its_path_result + f'its_results/its_plot_{name}.png'
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    plt.savefig(out_file)
    plt.show()


# --- Calculate Effect Size ---
def calculate_effect_size(data, arima_results, nsim, policy_week, name, m=None, seasonal=True):
    """
    Calculate the effect size from ARIMA simulation with optional seasonality.

    A counterfactual ARIMA model is fitted on the pre-policy observations and
    used to simulate `nsim` post-policy paths. The effect is the observed
    post-policy mean minus the average of the simulated path means; the spread
    of the simulated means gives the standard deviation, a normal-approximation
    95% interval, the z-score and a two-sided p-value. The one-row result table
    is written to CSV below `its_path_result`, shown, and returned.

    Parameters:
    - data: DataFrame with time series (date index, column 'y')
    - arima_results: Fitted ARIMA model (accepted for interface symmetry; not used here)
    - nsim: Number of simulations
    - policy_week: Date of intervention
    - name: Variable name for output
    - m: Seasonal period (`None` is treated as 1, i.e. no seasonal period)
    - seasonal: Whether to use seasonal ARIMA

    Returns:
    - DataFrame: the one-row effect-size table.

    Notes:
    - The simulations are not seeded here, so results vary between runs.
    - If all simulated means coincide the standard deviation is 0 and the
      z-score / p-value are not finite.
    """
    pre_y = data[data.index <= policy_week]['y']
    post_y = data[data.index > policy_week]['y']

    # Fit counterfactual model; the seasonal test needs a numeric period > 1.
    period = 1 if m is None else m
    n_diffs = ndiffs(pre_y, test='adf')
    if seasonal and period > 1:
        n_Diffs = nsdiffs(pre_y, m=period, max_D=3, test='ch')
    else:
        n_Diffs = 0

    arima_cf = auto_arima(
        pre_y,
        seasonal=seasonal,
        m=period,
        d=n_diffs,
        D=n_Diffs,
        start_p=0, start_q=0, max_p=5, max_q=5,
        start_P=0, start_Q=0, max_P=5, max_Q=5,
        trace=False,
        error_action='ignore',
        suppress_warnings=True,
        n_jobs=-1,
        stepwise=False
    )

    # Only the mean of each simulated path is needed, so the full paths are
    # not retained.
    horizon = post_y.shape[0]
    sim = [
        np.mean(arima_cf.arima_res_.simulate(horizon, anchor='end'))
        for _ in range(nsim)
    ]

    actual = np.mean(post_y)
    y_effect_mean_diff = actual - np.mean(sim)
    y_effect_std = np.std(sim)
    y_effect_lower_ci = y_effect_mean_diff - 1.96 * y_effect_std
    y_effect_upper_ci = y_effect_mean_diff + 1.96 * y_effect_std
    y_effect_zscore = y_effect_mean_diff / y_effect_std
    y_effect_pvalue = stats.norm.sf(abs(y_effect_zscore)) * 2

    effect_table = pd.DataFrame({
        'Variable': [name],
        'Mean Difference': [round(y_effect_mean_diff, 3)],
        'Std': [round(y_effect_std, 3)],
        'Lower CI': [round(y_effect_lower_ci, 3)],
        'Upper CI': [round(y_effect_upper_ci, 3)],
        'Z-score': [round(y_effect_zscore, 3)],
        'P-value': [round(y_effect_pvalue, 4)]
    })

    # Make sure the target folder exists before writing.
    out_file = its_path_result + f"its_results/its_effect_size_{name}.csv"
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    effect_table.to_csv(out_file, index=False)

    # `display` is only defined in IPython/Jupyter sessions; fall back to
    # print so the function also works when run as a plain script.
    try:
        show = display
    except NameError:
        show = print
    show(effect_table)

    return effect_table


def get_rescaled_df(df, col):
    """
    Return a copy of `df` with a min-max rescaled column 'y'.

    Parameters:
    - df (DataFrame): Source data; it is not modified.
    - col (str): Name of the column to rescale (may itself be 'y').

    Returns:
    - DataFrame: Copy of `df` whose 'y' column holds `col` mapped linearly
      onto [0, 1]. A constant column maps to 0.0 everywhere, because the
      min-max ratio would otherwise be 0/0.
    """
    rescaled = df.copy()

    min_value = rescaled[col].min()
    max_value = rescaled[col].max()
    value_range = max_value - min_value

    if value_range == 0:
        # Constant series: avoid dividing by zero.
        rescaled['y'] = 0.0
    else:
        rescaled['y'] = (rescaled[col] - min_value) / value_range

    return rescaled


# --- Paths & Constants ---
# Input data folder and output folder; both keep their trailing slash because
# file names are appended by plain string concatenation.
its_path_main = "./Replication_Materials/data/"
its_path_result = "./Results/"

# Date of the policy intervention and the week index used to place the
# vertical marker in the plots.
policy_week = datetime.datetime(year=2018, month=8, day=10)
policy_index = 53

# Lookup used to map 'week_old' values to week start dates.
# NOTE: pickle can execute arbitrary code while loading; only use a trusted file.
with open(f"{its_path_main}week_breaks_index.pkl", mode="rb") as f:
    week_breaks_index = pickle.load(f)


# --- Get Results --- 
def run_its_co_sharing_analysis(
    file_path,
    input_col,
    output_label,
    policy_week,
    m,
    nsim=1000,
    file_type="csv",  # or "pickle"
    seasonal=True
):
    """
    Run ITS + ARIMA analysis on co-sharing density (hashtags or domains).

    Loads the weekly series, rescales it to [0, 1], builds the ITS regressors,
    fits the OLS and ARIMA models, then produces the plot and the effect-size
    table.

    Parameters:
    - file_path (str): Path to input file (.csv or .pkl).
    - input_col (str): Column name for the co-sharing density.
    - output_label (str): Label for plots and output files.
    - policy_week (datetime): Intervention week.
    - m (int): Seasonal period for ARIMA.
    - nsim (int): Number of simulations for effect size.
    - file_type (str): "csv" or "pickle"
    - seasonal (bool): Whether to include seasonality in ARIMA. Defaults to
      True, which is what the modelling helpers used before this option was
      exposed here.

    Raises:
    - ValueError: If `file_type` is neither "csv" nor "pickle".
    """
    if file_type == "csv":
        df = pd.read_csv(file_path)
    elif file_type == "pickle":
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(file_path, "rb") as f:
            df = pickle.load(f)
    else:
        raise ValueError("file_type must be 'csv' or 'pickle'")

    # Keep only the week key and the outcome, under the names the ITS helpers expect.
    df = df[['week', input_col]].rename(columns={'week': 'week_old', input_col: 'y'})
    df = get_rescaled_df(df, 'y')

    its_data = get_users_agg_data_its(users_data_agg=df, policy_week=policy_week)
    ols_res = do_ols_its(its_data)
    arima_res = do_arima(data=its_data, ols_results=ols_res, m=m, seasonal=seasonal)

    do_arima_plot(
        data=its_data, arima_results=arima_res, policy_week=policy_week,
        m=m, name=output_label, seasonal=seasonal
    )
    calculate_effect_size(
        data=its_data, arima_results=arima_res, nsim=nsim,
        policy_week=policy_week, m=m, name=output_label, seasonal=seasonal
    )


# Co-sharing networks: hashtag network first, then domain (URL) network.
# Both weekly density series are read from CSV files.
for _csv_name, _density_col, _label in (
    (
        "weekly_density_cohashtag.csv",
        "pb_hashtag_co_sharing_G_using_OP_density",
        "Coshare hashtag network density",
    ),
    (
        "weekly_density_codomain.csv",
        "pb_domain_co_sharing_G_using_OP_density",
        "Coshare URLs network density",
    ),
):
    run_its_co_sharing_analysis(
        file_path=its_path_main + _csv_name,
        input_col=_density_col,
        output_label=_label,
        policy_week=policy_week,
        m=5,
        file_type="csv",
    )


def run_its_retweet_analysis(
    file_path,
    input_col,
    output_label,
    policy_week,
    m,
    seasonal=False,
    nsim=1000
):
    """
    Run the full ITS + ARIMA pipeline for one retweet-network density series.

    Steps: read the CSV, rescale the outcome, build the ITS regressors, fit
    OLS and ARIMA, then plot and tabulate the effect size.

    Parameters:
    - file_path (str): Path to CSV data file.
    - input_col (str): Column holding the outcome (e.g. density).
    - output_label (str): Label used for the plot and output file names.
    - policy_week (datetime): Week of intervention.
    - m (int): ARIMA seasonal period.
    - seasonal (bool): Whether to include seasonality in ARIMA.
    - nsim (int): Number of simulations for effect size.
    """
    raw = pd.read_csv(file_path)
    weekly = raw[['week', input_col]].rename(
        columns={'week': 'week_old', input_col: 'y'}
    )
    weekly = get_rescaled_df(weekly, 'y')

    its_data = get_users_agg_data_its(users_data_agg=weekly, policy_week=policy_week)
    ols_res = do_ols_its(its_data)
    arima_res = do_arima(data=its_data, ols_results=ols_res, m=m, seasonal=seasonal)

    # Arguments shared by the plotting and effect-size steps.
    shared_kwargs = dict(
        data=its_data,
        arima_results=arima_res,
        policy_week=policy_week,
        m=m,
        name=output_label,
        seasonal=seasonal,
    )
    do_arima_plot(**shared_kwargs)
    calculate_effect_size(nsim=nsim, **shared_kwargs)


# Retweet network density, analysed for three audiences in turn:
#   (1) all users
#   (2) broader audience (non-Proud Boys followers)
#   (3) insiders (Proud Boys followers)
for _csv_name, _density_col, _label in (
    (
        "weekly_density_retweet_all.csv",
        'density_retweet',
        'Retweet network density',
    ),
    (
        "weekly_density_retweet_non_pb.csv",
        'density_retweet_broader_audience',
        'Retweet network density (with non-Proud Boys followers)',
    ),
    (
        "weekly_density_retweet_pb.csv",
        'density_retweet_insiders',
        'Retweet network density (with Proud Boys followers)',
    ),
):
    run_its_retweet_analysis(
        file_path=its_path_main + _csv_name,
        input_col=_density_col,
        output_label=_label,
        policy_week=policy_week,
        m=5,
        seasonal=True,
    )
